{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2d479b4a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T07:47:51.787237Z",
     "start_time": "2022-09-19T07:47:51.778516Z"
    }
   },
   "source": [
    "Read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4b06eeb4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:09:42.796617Z",
     "start_time": "2022-09-19T08:09:42.434654Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append(\"E:/study_e/交通事故安全/code/mgtwr/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7d65c606",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:09:43.559791Z",
     "start_time": "2022-09-19T08:09:43.539778Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>longitude</th>\n",
       "      <th>latitude</th>\n",
       "      <th>t</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.771321</td>\n",
       "      <td>0.895098</td>\n",
       "      <td>10.656550</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.020752</td>\n",
       "      <td>0.633729</td>\n",
       "      <td>5.692754</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   longitude  latitude  t        x1        x2          y\n",
       "0          0         0  0  0.771321  0.895098  10.656550\n",
       "1          1         0  0  0.020752  0.633729   5.692754"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv(r'E:\\study_e\\交通事故安全\\code\\mgtwr\\data\\example.csv')\n",
    "data.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "796fe3e5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:09:44.798946Z",
     "start_time": "2022-09-19T08:09:44.787275Z"
    }
   },
   "outputs": [],
   "source": [
    "coords = data[['longitude', 'latitude']]\n",
    "t = data[['t']]\n",
    "X = data[['x1', 'x2']]\n",
    "y = data[['y']]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "905c6002",
   "metadata": {},
   "source": [
    "GWR model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "532ffaf3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:09:50.207629Z",
     "start_time": "2022-09-19T08:09:50.131373Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from mgtwr_local.sel import SearchGWRParameter\n",
    "from mgtwr_local.model import GWR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e0aa3e40",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:09:54.705874Z",
     "start_time": "2022-09-19T08:09:53.355340Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bw: 15.0 , score: 18778.49\n",
      "bw: 10.0 , score: 18764.75\n",
      "bw: 6.0 , score: 18699.21\n",
      "bw: 4.0 , score: 18506.22\n",
      "bw: 2.0 , score: 17786.86\n",
      "bw: 2.0 , score: 17786.86\n",
      "time cost: 0:00:3.111\n"
     ]
    }
   ],
   "source": [
    "sel = SearchGWRParameter(coords, X, y, kernel='gaussian', fixed=True)\n",
    "bw = sel.search(bw_max=40, verbose=True, time_cost=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cb3be837",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:10:32.986328Z",
     "start_time": "2022-09-19T08:10:32.709532Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5935790327518\n"
     ]
    }
   ],
   "source": [
    "gwr = GWR(coords, X, y, bw, kernel='gaussian', fixed=True).fit()\n",
    "print(gwr.R2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4e8e3527",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summaryGWR(Results):\n",
    "    XNames = [\"X\"+str(i) for i in range(Results.k)]\n",
    "    \n",
    "    summary = \"%s\\n\" %('Geographically weighted regression (GWR) Results')\n",
    "    summary += '-' * 75 + '\\n'\n",
    "\n",
    "    if Results.fixed:\n",
    "        summary += \"%-50s %20s\\n\" % ('Spatial kernel:', 'Fixed ' + Results.kernel)\n",
    "    else:\n",
    "        summary += \"%-54s %20s\\n\" % ('Spatial kernel:', 'Adaptive ' + Results.kernel)\n",
    "    \n",
    "    summary += \"%-62s %12.3f\\n\" % ('Temporal Bandwidth used:',  Results.bw)\n",
    "\n",
    "    summary += \"\\n%s\\n\" % ('Diagnostic information')\n",
    "    summary += '-' * 75 + '\\n'\n",
    "    \n",
    "    if Results.kernel == 'gaussian':\n",
    "        \n",
    "        summary += \"%-62s %12.3f\\n\" % ('Residual sum of squares:', Results.RSS)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Effective number of parameters (trace(S)):', Results.tr_S)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Degree of freedom (n - trace(S)):', Results.df_model)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Sigma estimate:', np.sqrt(Results.sigma2))\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Log-likelihood:', Results.llf)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AIC:', Results.aic)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AICc:', Results.aicc)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('BIC:', Results.bic)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('R2:', Results.R2)\n",
    "    else:\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Effective number of parameters (trace(S)):', Results.tr_S)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Degree of freedom (n - trace(S)):', Results.df_model)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Log-likelihood:', Results.llf)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AIC:', Results.aic)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AICc:', Results.aicc)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('BIC:', Results.bic)\n",
    "        #summary += \"%-60s %12.6f\\n\" % ('Percent deviance explained:', 0)\n",
    "\n",
    "\n",
    "    summary += \"%-62s %12.3f\\n\" % ('Adj. alpha (95%):', Results.adj_alpha[1])\n",
    "    summary += \"%-62s %12.3f\\n\" % ('Adj. critical t value (95%):', Results.critical_tval(Results.adj_alpha[1]))\n",
    "\n",
    "    summary += \"\\n%s\\n\" % ('Summary Statistics For GWR Parameter Estimates')\n",
    "    summary += '-' * 75 + '\\n'\n",
    "    summary += \"%-20s %10s %10s %10s %10s %10s\\n\" % ('Variable', 'Mean' ,'STD', 'Min' ,'Median', 'Max')\n",
    "    summary += \"%-20s %10s %10s %10s %10s %10s\\n\" % ('-'*20, '-'*10 ,'-'*10, '-'*10 ,'-'*10, '-'*10)\n",
    "    for i in range(Results.k):\n",
    "        summary += \"%-20s %10.3f %10.3f %10.3f %10.3f %10.3f\\n\" % (XNames[i], np.mean(Results.betas[:,i]) ,np.std(Results.betas[:,i]),np.min(Results.betas[:,i]) ,np.median(Results.betas[:,i]), np.max(Results.betas[:,i]))\n",
    "\n",
    "    summary += '=' * 75 + '\\n'\n",
    "\n",
    "    return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8de90267",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Geographically weighted regression (GWR) Results\n",
      "---------------------------------------------------------------------------\n",
      "Spatial kernel:                                          Fixed gaussian\n",
      "Temporal Bandwidth used:                                              2.000\n",
      "\n",
      "Diagnostic information\n",
      "---------------------------------------------------------------------------\n",
      "Residual sum of squares:                                        2901501.328\n",
      "Effective number of parameters (trace(S)):                           24.069\n",
      "Degree of freedom (n - trace(S)):                                  1703.931\n",
      "Sigma estimate:                                                      41.265\n",
      "Log-likelihood:                                                   -8868.006\n",
      "AIC:                                                              17786.151\n",
      "AICc:                                                             17786.859\n",
      "BIC:                                                              17765.831\n",
      "R2:                                                                   0.594\n",
      "Adj. alpha (95%):                                                     0.006\n",
      "Adj. critical t value (95%):                                          2.739\n",
      "\n",
      "Summary Statistics For GWR Parameter Estimates\n",
      "---------------------------------------------------------------------------\n",
      "Variable                   Mean        STD        Min     Median        Max\n",
      "-------------------- ---------- ---------- ---------- ---------- ----------\n",
      "X0                        2.659      3.131     -2.591      2.269     10.003\n",
      "X1                       10.028      5.455     -2.504     11.049     18.555\n",
      "X2                      117.226     42.589     33.553    115.155    200.927\n",
      "===========================================================================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(summaryGWR(gwr))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac6e9d39",
   "metadata": {},
   "source": [
    "MGWR model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "20f580b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mgtwr_local.sel import SearchMGWRParameter\n",
    "from mgtwr_local.model import MGWR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "08bf65d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current iteration: 1 ,SOC: 0.0033171\n",
      "Bandwidths: 986.8, 965.5, 0.7\n",
      "Current iteration: 2 ,SOC: 5.64e-05\n",
      "Bandwidths: 986.8, 986.8, 0.7\n",
      "Current iteration: 3 ,SOC: 4.27e-05\n",
      "Bandwidths: 986.8, 986.8, 0.7\n",
      "Current iteration: 4 ,SOC: 3.22e-05\n",
      "Bandwidths: 986.8, 986.8, 0.7\n",
      "Current iteration: 5 ,SOC: 2.43e-05\n",
      "Bandwidths: 986.8, 986.8, 0.7\n",
      "time cost: 0:01:20.431\n"
     ]
    }
   ],
   "source": [
    "sel_multi = SearchMGWRParameter(coords, X, y, kernel='gaussian', fixed=True)\n",
    "bws = sel_multi.search(multi_bw_max=[1000], verbose=True, time_cost=True, tol_multi=3.0e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e7dbf9fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7045779853867871\n"
     ]
    }
   ],
   "source": [
    "mgwr = MGWR(coords, X, y, sel_multi, kernel='gaussian', fixed=True).fit()\n",
    "print(mgwr.R2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68c915f1",
   "metadata": {},
   "source": [
    "If you already know bws, you can also do the following"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "56555609",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7045779853867871\n"
     ]
    }
   ],
   "source": [
    "class sel_multi:\n",
    "    def __init__(self, bws):\n",
    "        self.bws = bws\n",
    "\n",
    "        \n",
    "selector = sel_multi(bws)\n",
    "mgwr = MGWR(coords, X, y, selector, kernel='gaussian', fixed=True).fit()\n",
    "print(mgwr.R2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6aea4108",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:11:21.337967Z",
     "start_time": "2022-09-19T08:11:21.326547Z"
    }
   },
   "source": [
    "GTWR model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "462da66a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:11:53.026336Z",
     "start_time": "2022-09-19T08:11:53.021405Z"
    }
   },
   "outputs": [],
   "source": [
    "from mgtwr_local.sel import SearchGTWRParameter\n",
    "from mgtwr_local.model import GTWR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4f9cc821",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:14:07.489058Z",
     "start_time": "2022-09-19T08:13:28.866324Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bw:  5.9 , tau:  19.9 , score:  18095.04059255282\n",
      "bw:  3.7 , tau:  19.9 , score:  17608.38596885707\n",
      "bw:  2.3 , tau:  10.1 , score:  16461.58709937909\n",
      "bw:  1.4 , tau:  3.8 , score:  14817.811620052908\n",
      "bw:  0.9 , tau:  1.4 , score:  13780.792562049754\n",
      "bw:  0.9 , tau:  1.4 , score:  13780.792562049754\n",
      "bw:  0.9 , tau:  1.4 , score:  13780.792562049754\n",
      "bw:  0.9 , tau:  1.4 , score:  13780.792562049754\n",
      "bw:  0.9 , tau:  1.4 , score:  13780.792562049754\n",
      "time cost: 0:01:40.489\n"
     ]
    }
   ],
   "source": [
    "sel = SearchGTWRParameter(coords, t, X, y, kernel='gaussian', fixed=True)\n",
    "bw, tau = sel.search(tau_max=20, verbose=True, time_cost=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4bbf93f8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:14:17.776587Z",
     "start_time": "2022-09-19T08:14:17.313360Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9829884630503501\n"
     ]
    }
   ],
   "source": [
    "gtwr = GTWR(coords, t, X, y, bw, tau, kernel='gaussian', fixed=True).fit()\n",
    "print(gtwr.R2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "078a7154",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.771321</td>\n",
       "      <td>0.895098</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.020752</td>\n",
       "      <td>0.633729</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.633648</td>\n",
       "      <td>0.462768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.748804</td>\n",
       "      <td>0.090788</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.498507</td>\n",
       "      <td>0.982153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1723</th>\n",
       "      <td>0.289875</td>\n",
       "      <td>0.741360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1724</th>\n",
       "      <td>0.921317</td>\n",
       "      <td>0.962916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1725</th>\n",
       "      <td>0.222337</td>\n",
       "      <td>0.055764</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1726</th>\n",
       "      <td>0.633053</td>\n",
       "      <td>0.825171</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1727</th>\n",
       "      <td>0.057299</td>\n",
       "      <td>0.872990</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1728 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            x1        x2\n",
       "0     0.771321  0.895098\n",
       "1     0.020752  0.633729\n",
       "2     0.633648  0.462768\n",
       "3     0.748804  0.090788\n",
       "4     0.498507  0.982153\n",
       "...        ...       ...\n",
       "1723  0.289875  0.741360\n",
       "1724  0.921317  0.962916\n",
       "1725  0.222337  0.055764\n",
       "1726  0.633053  0.825171\n",
       "1727  0.057299  0.872990\n",
       "\n",
       "[1728 rows x 2 columns]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e8421e26",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.        , 0.77132064, 0.89509824],\n",
       "       [1.        , 0.02075195, 0.63372887],\n",
       "       [1.        , 0.63364823, 0.46276849],\n",
       "       ...,\n",
       "       [1.        , 0.22233723, 0.05576403],\n",
       "       [1.        , 0.63305281, 0.82517067],\n",
       "       [1.        , 0.05729903, 0.87299008]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gtwr.X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "db6ae181",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1728, 3)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gtwr.X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6f126dac",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 4.61469855,  5.45243831,  2.99766617],\n",
       "       [ 2.32808632,  4.72600524,  8.75860259],\n",
       "       [-0.52334178,  5.58047125, 16.77459138],\n",
       "       ...,\n",
       "       [ 8.90916936, -7.28382225, 50.68065614],\n",
       "       [ 8.44766537, 10.24180477, 24.51544434],\n",
       "       [10.99289026, 12.95482629, 12.81980239]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gtwr.betas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2cfc60da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1728, 3)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gtwr.betas.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "4a58420c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['const', 'x1', 'x2']"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "columns_list = list(X.columns)\n",
    "columns_list.insert(0,'const')\n",
    "columns_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b53bceba",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>const</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4.614699</td>\n",
       "      <td>5.452438</td>\n",
       "      <td>2.997666</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2.328086</td>\n",
       "      <td>4.726005</td>\n",
       "      <td>8.758603</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-0.523342</td>\n",
       "      <td>5.580471</td>\n",
       "      <td>16.774591</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.704357</td>\n",
       "      <td>2.380499</td>\n",
       "      <td>12.812832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6.740300</td>\n",
       "      <td>1.208885</td>\n",
       "      <td>6.016846</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1723</th>\n",
       "      <td>7.839976</td>\n",
       "      <td>-3.918450</td>\n",
       "      <td>59.958556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1724</th>\n",
       "      <td>10.822673</td>\n",
       "      <td>-17.987315</td>\n",
       "      <td>62.384594</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1725</th>\n",
       "      <td>8.909169</td>\n",
       "      <td>-7.283822</td>\n",
       "      <td>50.680656</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1726</th>\n",
       "      <td>8.447665</td>\n",
       "      <td>10.241805</td>\n",
       "      <td>24.515444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1727</th>\n",
       "      <td>10.992890</td>\n",
       "      <td>12.954826</td>\n",
       "      <td>12.819802</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1728 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          const         x1         x2\n",
       "0      4.614699   5.452438   2.997666\n",
       "1      2.328086   4.726005   8.758603\n",
       "2     -0.523342   5.580471  16.774591\n",
       "3      3.704357   2.380499  12.812832\n",
       "4      6.740300   1.208885   6.016846\n",
       "...         ...        ...        ...\n",
       "1723   7.839976  -3.918450  59.958556\n",
       "1724  10.822673 -17.987315  62.384594\n",
       "1725   8.909169  -7.283822  50.680656\n",
       "1726   8.447665  10.241805  24.515444\n",
       "1727  10.992890  12.954826  12.819802\n",
       "\n",
       "[1728 rows x 3 columns]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame(gtwr.betas , columns= columns_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9bd527a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summaryGTWR(Results):\n",
    "    XNames = [\"X\"+str(i) for i in range(Results.k)]\n",
    "    \n",
    "    summary = \"%s\\n\" %('Spatiotemporal Weighted Regression (GTWR) Results')\n",
    "    summary += '-' * 75 + '\\n'\n",
    "\n",
    "    if Results.fixed:\n",
    "        summary += \"%-50s %20s\\n\" % ('Spatial kernel:', 'Fixed ' + Results.kernel)\n",
    "    else:\n",
    "        summary += \"%-54s %20s\\n\" % ('Spatial kernel:', 'Adaptive ' + Results.kernel)\n",
    "\n",
    "    summary += \"%-62s %12.3f\\n\" % ('Model tau used:', Results.tau)\n",
    "        \n",
    "    summary += \"%-62s %12.3f\\n\" % ('Spatial Bandwidth used:', Results.bw_s)\n",
    "    \n",
    "    summary += \"%-62s %12.3f\\n\" % ('Temporal Bandwidth used:',  Results.bw_t)\n",
    "\n",
    "    summary += \"\\n%s\\n\" % ('Diagnostic information')\n",
    "    summary += '-' * 75 + '\\n'\n",
    "    \n",
    "    if Results.kernel == 'gaussian':\n",
    "        \n",
    "        summary += \"%-62s %12.3f\\n\" % ('Residual sum of squares:', Results.RSS)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Effective number of parameters (trace(S)):', Results.tr_S)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Degree of freedom (n - trace(S)):', Results.df_model)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Sigma estimate:', np.sqrt(Results.sigma2))\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Log-likelihood:', Results.llf)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AIC:', Results.aic)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AICc:', Results.aicc)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('BIC:', Results.bic)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('R2:', Results.R2)\n",
    "    else:\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Effective number of parameters (trace(S)):', Results.tr_S)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Degree of freedom (n - trace(S)):', Results.df_model)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('Log-likelihood:', Results.llf)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AIC:', Results.aic)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('AICc:', Results.aicc)\n",
    "        summary += \"%-62s %12.3f\\n\" % ('BIC:', Results.bic)\n",
    "        #summary += \"%-60s %12.6f\\n\" % ('Percent deviance explained:', 0)\n",
    "\n",
    "\n",
    "    summary += \"%-62s %12.3f\\n\" % ('Adj. alpha (95%):', Results.adj_alpha[1])\n",
    "    summary += \"%-62s %12.3f\\n\" % ('Adj. critical t value (95%):', Results.critical_tval(Results.adj_alpha[1]))\n",
    "\n",
    "    summary += \"\\n%s\\n\" % ('Summary Statistics For GTWR Parameter Estimates')\n",
    "    summary += '-' * 75 + '\\n'\n",
    "    summary += \"%-20s %10s %10s %10s %10s %10s\\n\" % ('Variable', 'Mean' ,'STD', 'Min' ,'Median', 'Max')\n",
    "    summary += \"%-20s %10s %10s %10s %10s %10s\\n\" % ('-'*20, '-'*10 ,'-'*10, '-'*10 ,'-'*10, '-'*10)\n",
    "    for i in range(Results.k):\n",
    "        summary += \"%-20s %10.3f %10.3f %10.3f %10.3f %10.3f\\n\" % (XNames[i], np.mean(Results.betas[:,i]) ,np.std(Results.betas[:,i]),np.min(Results.betas[:,i]) ,np.median(Results.betas[:,i]), np.max(Results.betas[:,i]))\n",
    "\n",
    "    summary += '=' * 75 + '\\n'\n",
    "\n",
    "    return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "522ef54c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spatiotemporal Weighted Regression (STWR) Results\n",
      "---------------------------------------------------------------------------\n",
      "Spatial kernel:                                          Fixed gaussian\n",
      "Model tau used:                                                       1.400\n",
      "Spatial Bandwidth used:                                               0.900\n",
      "Temporal Bandwidth used:                                              0.761\n",
      "\n",
      "Diagnostic information\n",
      "---------------------------------------------------------------------------\n",
      "Residual sum of squares:                                         121447.959\n",
      "Effective number of parameters (trace(S)):                          529.125\n",
      "Degree of freedom (n - trace(S)):                                  1198.875\n",
      "Sigma estimate:                                                      10.065\n",
      "Log-likelihood:                                                   -6126.104\n",
      "AIC:                                                              13312.459\n",
      "AICc:                                                             13780.793\n",
      "BIC:                                                              12282.027\n",
      "R2:                                                                   0.983\n",
      "Adj. alpha (95%):                                                     0.000\n",
      "Adj. critical t value (95%):                                          3.637\n",
      "\n",
      "Summary Statistics For GTWR Parameter Estimates\n",
      "---------------------------------------------------------------------------\n",
      "Variable                   Mean        STD        Min     Median        Max\n",
      "-------------------- ---------- ---------- ---------- ---------- ----------\n",
      "X0                        4.504      9.645    -32.841      5.187     48.965\n",
      "X1                        6.362     14.354    -49.833      5.397     62.056\n",
      "X2                      111.738     85.144     -2.873     91.440    341.120\n",
      "===========================================================================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(summaryGTWR(gtwr) )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ad9399f",
   "metadata": {},
   "source": [
    "MGTWR model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7d015f1a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:15:02.313810Z",
     "start_time": "2022-09-19T08:15:02.303789Z"
    }
   },
   "outputs": [],
   "source": [
    "from mgtwr_local.sel import SearchMGTWRParameter\n",
    "from mgtwr_local.model import MGTWR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "94d738b5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:23:08.330524Z",
     "start_time": "2022-09-19T08:15:42.813827Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current iteration: 1 ,SOC: 0.0025274\n",
      "Bandwidths: 0.7, 0.7, 0.5\n",
      "taus: 1.3,0.8,0.8\n",
      "Current iteration: 2 ,SOC: 0.0011033\n",
      "Bandwidths: 0.9, 0.7, 0.5\n",
      "taus: 3.0,0.4,0.8\n",
      "Current iteration: 3 ,SOC: 0.0005365\n",
      "Bandwidths: 0.9, 0.7, 0.5\n",
      "taus: 3.4,0.2,0.8\n",
      "Current iteration: 4 ,SOC: 0.0003\n",
      "Bandwidths: 0.9, 0.7, 0.5\n",
      "taus: 3.4,0.2,0.8\n",
      "Current iteration: 5 ,SOC: 0.0001986\n",
      "Bandwidths: 0.9, 0.7, 0.5\n",
      "taus: 3.6,0.2,0.8\n",
      "Current iteration: 6 ,SOC: 0.0001415\n",
      "Bandwidths: 0.9, 0.7, 0.5\n",
      "taus: 3.6,0.2,0.8\n",
      "Current iteration: 7 ,SOC: 0.0001052\n",
      "Bandwidths: 0.9, 0.7, 0.5\n",
      "taus: 3.6,0.2,0.8\n",
      "Current iteration: 8 ,SOC: 7.99e-05\n",
      "Bandwidths: 0.9, 0.7, 0.5\n",
      "taus: 3.6,0.2,0.8\n",
      "time cost: 0:19:26.941\n"
     ]
    }
   ],
   "source": [
    "sel_multi = SearchMGTWRParameter(coords, t, X, y, kernel='gaussian', fixed=True)\n",
    "bws = sel_multi.search(multi_bw_min=[0.1], verbose=True, tol_multi=1.0e-4, time_cost=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "51401611",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:24:31.131209Z",
     "start_time": "2022-09-19T08:24:16.718379Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9972924820674222\n"
     ]
    }
   ],
   "source": [
    "mgtwr = MGTWR(coords, t, X, y, sel_multi, kernel='gaussian', fixed=True).fit()\n",
    "print(mgtwr.R2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "541bdcce",
   "metadata": {},
   "source": [
    "If you already know bws, you can also do the following"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "bcfc1992",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-19T08:25:21.934146Z",
     "start_time": "2022-09-19T08:25:08.333204Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9972924820674222\n"
     ]
    }
   ],
   "source": [
    "class sel_multi:\n",
    "    def __init__(self, bws):\n",
    "        self.bws = bws\n",
    "\n",
    "        \n",
    "selector = sel_multi(bws)\n",
    "mgtwr = MGTWR(coords, t, X, y, selector, kernel='gaussian', fixed=True).fit()\n",
    "print(mgtwr.R2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0534878",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
